import os
import re
import tempfile

import deepspeed
import gradio as gr
import torch
from PIL import Image, ImageDraw
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Path to the merged model weights (base model + adapter merged beforehand).
merged_model_path = "./merged_model"
# Pin every weight of this process to its local GPU rank; falls back to GPU 0
# when LOCAL_RANK is unset (e.g. a plain single-process launch).
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 

# Load the merged Qwen2.5-VL model in bfloat16 onto this rank's GPU.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    merged_model_path,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    trust_remote_code=True
)

# Wrap the model for DeepSpeed tensor-parallel inference.
# NOTE(review): mp_size=3 hard-codes a 3-GPU deployment — confirm it matches
# the actual launch world size, and that kernel injection is supported for
# this architecture (injection support varies by model family).
model = deepspeed.init_inference(
    model=model,      # the Transformers model to shard
    mp_size=3,        # tensor-parallel degree (number of GPUs)
    dtype=torch.bfloat16, # weight/compute dtype (bfloat16)
    replace_method="auto", # let DeepSpeed auto-detect layers to replace
    replace_with_kernel_inject=True, # replace with fused inference kernels
)

# Load the matching processor (tokenizer + image preprocessor).
processor = AutoProcessor.from_pretrained(merged_model_path, use_fast=True, trust_remote_code=True)

def get_model_prediction(image, description):
    """Ask the model for the bounding box of *description* within *image*.

    Args:
        image: a PIL.Image to search in.
        description: text naming the object to locate.

    Returns:
        list[int]: ``[x1, y1, x2, y2]`` parsed from the model's text reply.

    Raises:
        ValueError: if the model's reply does not contain four integers.
    """
    # Write the image to a unique temp file so concurrent Gradio requests
    # cannot overwrite each other (the previous fixed "temp_image.jpg" raced
    # and was never cleaned up).
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        image_path = tmp.name
    try:
        image.save(image_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": f"Provide the bounding box for the following object: {description}"},
                ],
            }
        ]

        # Render the chat template and extract vision inputs for the processor.
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.module.device)

        generated_ids = model.generate(**inputs, max_new_tokens=128)
        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    finally:
        # Best-effort cleanup of the temporary image file.
        try:
            os.remove(image_path)
        except OSError:
            pass

    # Interpret the last four integers of the reply as x1, y1, x2, y2.
    numbers = re.findall(r'\d+', output_text)[-4:]
    if len(numbers) < 4:
        # Previously a short list was returned silently and crashed later in
        # ImageDraw with an opaque IndexError; fail loudly with the raw reply.
        raise ValueError(f"Could not parse a bounding box from model output: {output_text!r}")
    return list(map(int, numbers))

def process(image, question):
    """Gradio callback: draw the model's predicted bounding box on the image.

    Args:
        image: the uploaded PIL.Image, or None when nothing was uploaded.
        question: the object description typed by the user.

    Returns:
        An RGB copy of the image with the predicted box drawn in red.

    Raises:
        gr.Error: when no image was uploaded or no box could be parsed.
    """
    if image is None:
        # The output component is a gr.Image; returning a plain message string
        # there makes Gradio treat it as an image file path and fail. Raise
        # gr.Error instead so the message is shown to the user as a popup.
        raise gr.Error("请上传一张图片")

    bbox = get_model_prediction(image, question)
    if len(bbox) < 4:
        # Defensive: the model reply may not contain four coordinates, which
        # would previously surface as an IndexError in the rectangle call.
        raise gr.Error("未能从模型输出中解析出边界框，请重试")

    annotated = image.convert("RGB")  # convert() returns a copy; upload stays untouched
    draw = ImageDraw.Draw(annotated)
    draw.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], outline="red", width=3)
    return annotated

# Build the Gradio UI: two side-by-side image panels (upload / annotated
# result), a text box for the object description, and a submit button.
with gr.Blocks(css=".input-image {max-width: 48%; margin: 0.5%;}") as demo:
    gr.Markdown("# 图像对象检测")
    
    with gr.Row():
        image_input = gr.Image(type="pil", label="上传图片", elem_classes="input-image")
        output_image = gr.Image(label="检测结果", elem_classes="input-image")
    
    question_input = gr.Textbox(label="问题或描述", placeholder="请输入您想识别的物体名称...")
    
    btn = gr.Button("提交", variant="primary")

    # Wire the submit button to the detection callback.
    btn.click(fn=process, inputs=[image_input, question_input], outputs=output_image)

# NOTE(review): this launches at import time — consider guarding with
# `if __name__ == "__main__":`. share=True exposes a public tunnel URL.
demo.launch(share=True)